{ "cells": [ { "cell_type": "markdown", "id": "cMny8Ri7RvqC", "metadata": { "id": "cMny8Ri7RvqC" }, "source": [ "\n", "### **2. T-learner**\n", "The second learner is called the T-learner, short for \"two learners\". Instead of fitting a single model to estimate the potential outcomes under both treatment and control, the T-learner fits separate models for $\\mathbb{E}[R(1)|S]$ and $\\mathbb{E}[R(0)|S]$ and then takes their difference to obtain the HTE estimator.\n", "\n", "Define the control response function as $\\mu_0(s)=\\mathbb{E}[R(0)|S=s]$, and the treatment response function as $\\mu_1(s)=\\mathbb{E}[R(1)|S=s]$. The T-learner algorithm is summarized below:\n", "\n", "**Step 1:** Estimate $\\mu_0(s)$ and $\\mu_1(s)$ separately with any regression algorithm or supervised machine learning method, fitting $\\mu_0$ on the control-group data and $\\mu_1$ on the treatment-group data;\n", "\n", "**Step 2:** Estimate the HTE by\n", "\\begin{equation*}\n", "\\hat{\\tau}_{\\text{T-learner}}(s)=\\hat\\mu_1(s)-\\hat\\mu_0(s).\n", "\\end{equation*}\n", "\n" ] }, { "cell_type": "code", "execution_count": 1, "id": "eRpP5k9MBtzO", "metadata": { "id": "eRpP5k9MBtzO" }, "outputs": [], "source": [ "# import the required packages\n", "import numpy as np\n", "import pandas as pd\n", "from matplotlib import pyplot as plt\n", "from sklearn.ensemble import GradientBoostingRegressor\n", "from sklearn.linear_model import LinearRegression\n", "from causaldm.learners.CEL.Single_Stage import _env_getdata_CEL" ] }, { "cell_type": "markdown", "id": "XUu695Qrf61-", "metadata": { "id": "XUu695Qrf61-" }, "source": [ "### MovieLens Data" ] }, { "cell_type": "code", "execution_count": 2, "id": "JhfJntzcVVy2", "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 424 }, "executionInfo": { "elapsed": 288, "status": "ok", "timestamp": 1676750101543, "user": { "displayName": "Yang Xu", "userId": "12270366590264264299" }, "user_tz": 300 }, "id": "JhfJntzcVVy2", "outputId": "7fab8a7a-7cd9-445c-a005-9a6d1994a071" }, "outputs": [ { "data": { 
"text/html": [ "

|  | user_id | movie_id | rating | age | Drama | Sci-Fi | gender_M | occupation_academic/educator | occupation_college/grad student | occupation_executive/managerial | occupation_other | occupation_technician/engineer |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 48.0 | 1193.0 | 4.0 | 25.0 | 1.0 | 0.0 | 1.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 |
| 1 | 48.0 | 919.0 | 4.0 | 25.0 | 1.0 | 0.0 | 1.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 |
| 2 | 48.0 | 527.0 | 5.0 | 25.0 | 1.0 | 0.0 | 1.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 |
| 3 | 48.0 | 1721.0 | 4.0 | 25.0 | 1.0 | 0.0 | 1.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 |
| 4 | 48.0 | 150.0 | 4.0 | 25.0 | 1.0 | 0.0 | 1.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 65637 | 5878.0 | 3300.0 | 2.0 | 25.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 |
| 65638 | 5878.0 | 1391.0 | 1.0 | 25.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 |
| 65639 | 5878.0 | 185.0 | 4.0 | 25.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 |
| 65640 | 5878.0 | 2232.0 | 1.0 | 25.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 |
| 65641 | 5878.0 | 426.0 | 3.0 | 25.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 |

65642 rows × 12 columns
\n", "